{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a595904e-f1c2-4f6c-8101-01be901881ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2.6 Probability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "52866d75-c7e1-4df9-862e-0ef873e0cef2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\.libs\\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll\n",
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\.libs\\libopenblas64__v0.3.23-246-g3d31191b-gcc_10_3_0.dll\n",
      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n"
     ]
    }
   ],
   "source": [
    "# 2.6.1 Basic probability theory\n",
    "# The law of large numbers tells us that as the number of tosses grows,\n",
    "# the estimated probability of an outcome (its observed relative frequency)\n",
    "# gets closer and closer to the true underlying probability.\n",
    "%matplotlib inline\n",
    "import torch\n",
    "from torch.distributions import multinomial\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3a69a9ea-e7b9-4eb3-a2b2-ee17fdab2f9e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 0., 0., 0., 1., 0.])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed torch's global random generator before the first draw so that\n",
    "# Restart and Run All yields the same samples on every run.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Six equally likely outcomes: every entry is 1/6.\n",
    "fair_probs = torch.ones([6]) / 6\n",
    "\n",
    "# A single draw: the sample is a count vector with a 1 at the drawn outcome.\n",
    "multinomial.Multinomial(1, fair_probs).sample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b46723cf-d02c-48be-b9fc-1ac1cc6662c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2., 4., 1., 1., 2., 0.])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ten draws in a single sample; the result holds one count per outcome.\n",
    "multinomial.Multinomial(total_count=10, probs=fair_probs).sample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2de12b6f-3cbb-4db5-8a17-6bfa0c943d00",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.1780, 0.1750, 0.1690, 0.1500, 0.1600, 0.1680])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The counts are stored as 32-bit floats, so they can be divided directly\n",
    "# to turn 1000 draws into relative frequencies.\n",
    "counts = multinomial.Multinomial(total_count=1000, probs=fair_probs).sample()\n",
    "counts / 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4af1b40f-4fcb-44a1-907a-c325b64e8694",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3000, 0.3000, 0.1000, 0.0000, 0.2000, 0.1000],\n",
       "        [0.2000, 0.1500, 0.2000, 0.1500, 0.2500, 0.0500],\n",
       "        [0.2000, 0.2000, 0.1667, 0.1333, 0.1667, 0.1333],\n",
       "        ...,\n",
       "        [0.1622, 0.1620, 0.1669, 0.1647, 0.1703, 0.1739],\n",
       "        [0.1619, 0.1625, 0.1665, 0.1645, 0.1707, 0.1737],\n",
       "        [0.1618, 0.1626, 0.1668, 0.1642, 0.1710, 0.1736]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run 500 groups of experiments, drawing 10 samples in each group.\n",
    "counts = multinomial.Multinomial(10, fair_probs).sample((500,))\n",
    "\n",
    "# Running total of every outcome's count after each group (summed down the rows).\n",
    "cum_counts = counts.cumsum(dim=0)\n",
    "\n",
    "# Divide each row by its own total so the row becomes a probability estimate.\n",
    "# `keepdim` is the keyword documented for Tensor.sum; keeping the reduced axis\n",
    "# lets the per-row totals broadcast against the per-outcome counts.\n",
    "estimates = cum_counts / cum_counts.sum(dim=1, keepdim=True)\n",
    "estimates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "780043ae-e968-4895-8bed-e4b2c0ac0b11",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
